#!/opt/pwn.college/python

import argparse
import ctypes
import grp
import os
import shutil
import socket
import subprocess
import sys
import tempfile
import textwrap
import time

# Command values passed to libc's reboot() by vm_power_off() and vm_reboot().
# They correspond to the LINUX_REBOOT_CMD_* magic numbers in <linux/reboot.h>.
LINUX_REBOOT_CMD_POWER_OFF = 0x4321fedc
LINUX_REBOOT_CMD_RESTART = 0x1234567

# Handle to the C library; used for the direct reboot() calls below.
libc = ctypes.CDLL("libc.so.6")


def error(msg):
    """Print ``msg`` to standard error and terminate with exit status 1.

    Args:
        msg: Text to show to the user.

    Raises:
        SystemExit: Always, with exit code 1.
    """
    print(msg, file=sys.stderr)
    # Use sys.exit rather than the bare ``exit`` helper: the latter is injected
    # by the ``site`` module and is not available under ``python -S``.
    sys.exit(1)


def initialize():
    """Adjust the effective group ID and make sure ``/run/vm`` exists.

    The effective GID is set to the numeric value of the current effective
    UID. The runtime directory is created if missing; an existing one is left
    untouched.
    """
    effective_uid = os.geteuid()
    os.setegid(effective_uid)

    try:
        os.mkdir("/run/vm")
    except FileExistsError:
        # Already present; nothing to do.
        return


def vm_hostname():
    """Return the host name to use when connecting to the VM.

    Returns:
        ``"vm"`` when ``/etc/hosts`` contains the text ``127.0.0.1<TAB>vm``,
        otherwise the literal address ``"127.0.0.1"``.
    """
    with open("/etc/hosts") as hosts_file:
        hosts = hosts_file.read()
    return "vm" if "127.0.0.1\tvm" in hosts else "127.0.0.1"


def is_privileged():
    """Report whether the invoking user may use the privileged features.

    Returns:
        True when the real user ID is 0, or when one of the process's
        supplementary groups is named ``sudo``; False otherwise.
    """
    # Root needs no group lookup at all.
    if os.getuid() == 0:
        return True

    for gid in os.getgroups():
        try:
            name = grp.getgrgid(gid).gr_name
        except KeyError:
            # A GID with no entry in the group database has no name, so it
            # cannot be the "sudo" group; skip it instead of crashing.
            continue
        if name == "sudo":
            return True
    return False


def in_vm():
    """Report whether this process is running inside the VM.

    Returns:
        True when the kernel command line in ``/proc/cmdline`` contains the
        whitespace-separated token ``init=/opt/pwn.college/vm/init``.
    """
    # Context manager so the file handle is closed deterministically.
    with open("/proc/cmdline") as cmdline:
        return "init=/opt/pwn.college/vm/init" in cmdline.read().split()


def vm_power_off():
    """Power the machine off through libc's ``reboot`` call.

    Raises:
        OSError: If the call reports failure by returning -1.
    """
    # Deliberately not wrapped in ``assert``: assert statements are removed
    # under ``python -O`` / PYTHONOPTIMIZE, which would drop the call itself.
    if libc.reboot(LINUX_REBOOT_CMD_POWER_OFF) == -1:
        raise OSError("reboot(LINUX_REBOOT_CMD_POWER_OFF) failed")


def vm_reboot():
    """Restart the machine through libc's ``reboot`` call.

    Raises:
        OSError: If the call reports failure by returning -1.
    """
    # Deliberately not wrapped in ``assert``: assert statements are removed
    # under ``python -O`` / PYTHONOPTIMIZE, which would drop the call itself.
    if libc.reboot(LINUX_REBOOT_CMD_RESTART) == -1:
        raise OSError("reboot(LINUX_REBOOT_CMD_RESTART) failed")


def extra_boot_flags():
    """Collect the optional kernel command-line flags for the VM.

    KASLR is disabled unless ``/challenge/.kaslr`` exists; an explicit
    ``args.nokaslr`` value (anything other than None) overrides that choice.
    ``/challenge/.nopti`` adds ``nopti`` and ``/challenge/.panic_on_oops``
    adds ``oops=panic`` together with ``panic_on_warn=1``.

    Returns:
        list of str: The selected flags, in the order nokaslr, nopti, panic
        options.
    """
    disable_kaslr = not os.path.exists("/challenge/.kaslr")
    if args.nokaslr is not None:
        # Command-line choice wins over the marker file.
        disable_kaslr = args.nokaslr

    boot_flags = []
    if disable_kaslr:
        boot_flags.append("nokaslr")
    if os.path.exists("/challenge/.nopti"):
        boot_flags += ["nopti"]
    if os.path.exists("/challenge/.panic_on_oops"):
        boot_flags += ["oops=panic", "panic_on_warn=1"]
    return boot_flags


def execve(argv):
    """Switch to the real user and group IDs, then replace this process.

    Args:
        argv: Argument vector; ``argv[0]`` is the path of the program to run.
            The current environment is passed through unchanged.
    """
    # Effective UID first, then effective GID, before handing over control.
    real_uid = os.getuid()
    os.seteuid(real_uid)
    real_gid = os.getgid()
    os.setegid(real_gid)

    program = argv[0]
    os.execve(program, argv, os.environ)


def start():
    """Launch the QEMU virtual machine as a background daemon.

    QEMU is started through ``start-stop-daemon`` with ``/run/vm/vm.pid`` as
    its pidfile, and output is appended to ``/run/vm/vm.log``. The kernel
    image is ``/challenge/bzImage`` when present, otherwise
    ``/opt/linux/bzImage``. KVM acceleration is used when ``/dev/kvm`` exists.

    Raises:
        subprocess.CalledProcessError: If ``start-stop-daemon`` exits non-zero.
    """
    flags = " ".join(extra_boot_flags())

    kernel_image = "/challenge/bzImage" if os.path.exists("/challenge/bzImage") else "/opt/linux/bzImage"

    kvm = os.path.exists("/dev/kvm")
    cpu = "host" if kvm else "qemu64"
    qemu_argv = [
        "/usr/bin/qemu-system-x86_64",
        "-kernel", kernel_image,
        "-cpu", f"{cpu},smep,smap",
        "-fsdev", "local,id=rootfs,path=/,security_model=passthrough",
        "-device", "virtio-9p-pci,fsdev=rootfs,mount_tag=/dev/root",
        "-fsdev", "local,id=homefs,path=/home/hacker,security_model=passthrough",
        "-device", "virtio-9p-pci,fsdev=homefs,mount_tag=/home/hacker",
        "-device", "e1000,netdev=net0",
        "-netdev", "user,id=net0,hostfwd=tcp::22-:22",
        "-m", "2G",
        "-smp", "2" if kvm else "1",
        "-nographic",
        "-monitor", "none",
        "-append", f"rw rootfstype=9p rootflags=trans=virtio console=ttyS0 init=/opt/pwn.college/vm/init {flags}",
    ]
    if kvm:
        qemu_argv.append("-enable-kvm")

    if is_privileged():
        # Only privileged users get QEMU's gdb stub (used by the debug command).
        qemu_argv.append("-s")

    argv = [
        "/usr/sbin/start-stop-daemon",
        "--start",
        "--pidfile", "/run/vm/vm.pid",
        "--make-pidfile",
        "--background",
        "--no-close",
        "--quiet",
        "--oknodo",
        "--startas", qemu_argv[0],
        "--",
        *qemu_argv[1:]
    ]

    # Close our log handle once start-stop-daemon returns; the child process
    # holds its own inherited copy of the descriptor.
    with open("/run/vm/vm.log", "a") as log:
        subprocess.run(argv,
                       stdin=subprocess.DEVNULL,
                       stdout=log,
                       stderr=subprocess.STDOUT,
                       check=True)


def stop():
    """Stop the VM daemon and delete its log file.

    Asks ``start-stop-daemon`` to stop the process recorded in
    ``/run/vm/vm.pid`` (removing the pidfile), appending any output to
    ``/run/vm/vm.log``, and then unlinks that log.

    Raises:
        subprocess.CalledProcessError: If ``start-stop-daemon`` exits non-zero.
    """
    argv = [
        "/usr/sbin/start-stop-daemon",
        "--stop",
        "--pidfile", "/run/vm/vm.pid",
        "--remove-pidfile",
        "--quiet",
        "--oknodo",
    ]

    # Context manager so the log handle is closed before the file is removed.
    with open("/run/vm/vm.log", "a") as log:
        subprocess.run(argv,
                       stdin=subprocess.DEVNULL,
                       stdout=log,
                       stderr=subprocess.STDOUT,
                       check=True)

    os.unlink("/run/vm/vm.log")


def wait():
    """Block until the VM's SSH server answers on port 22.

    Makes up to 50 attempts, 0.1 seconds apart, to connect and read the first
    three bytes sent by the server. Returns as soon as those bytes are
    ``b"SSH"``. If no attempt succeeds, failure is reported through ``error``,
    which exits the process. Socket errors other than a refused or reset
    connection (for example timeouts) propagate to the caller.
    """
    # Loop-invariant: resolve once instead of re-reading /etc/hosts per attempt.
    host = vm_hostname()

    for _ in range(50):
        try:
            # The context manager closes the socket even if recv() raises.
            with socket.create_connection((host, 22), timeout=30) as connection:
                data = connection.recv(3)
        except (ConnectionRefusedError, ConnectionResetError):
            # Port not open yet, or the connection was dropped before any
            # banner arrived: both mean "not ready", so retry.
            pass
        else:
            if data == b"SSH":
                break
        time.sleep(0.1)
    else:
        error("Error: could not connect to vm")


def connect():
    """Open an SSH session to the VM.

    Waits until the VM's SSH server responds, then replaces the current
    process with ``ssh`` targeting the VM host.
    """
    wait()
    ssh_argv = ["/usr/bin/ssh"]
    ssh_argv.append(vm_hostname())
    execve(ssh_argv)


def exec_(*args):
    """Run a command inside the VM over SSH.

    Waits for the SSH server, then replaces the current process with ``ssh``.
    A pseudo-terminal is requested (``-t``) only when standard output is a
    terminal.

    Args:
        *args: The command and its arguments, passed to ssh after ``--``.
    """
    wait()
    tty_option = ["-t"] if sys.stdout.isatty() else []
    execve(["/usr/bin/ssh", *tty_option, vm_hostname(), "--", *args])


def debug():
    """Attach gdb to the VM's debug port.

    First probes TCP port 1234 on the VM host and exits through ``error`` if
    the connection is refused. Otherwise the process is replaced by gdb, which
    connects to ``localhost:1234`` and loads ``/challenge/vmlinux`` when
    present or ``/opt/linux/vmlinux`` otherwise.
    """
    try:
        # Reachability probe only: close it immediately rather than leaving
        # the socket open.
        socket.create_connection((vm_hostname(), 1234), timeout=30).close()
    except ConnectionRefusedError:
        error("Error: could not connect to debug")

    vmlinux = "/challenge/vmlinux" if os.path.exists("/challenge/vmlinux") else "/opt/linux/vmlinux"

    execve([
        "/usr/bin/gdb",
        "--ex", "target remote localhost:1234",
        vmlinux,
    ])


def build(path):
    """Compile a kernel module from a C source file and install it.

    The real user ID is first set to the effective one. The source is then
    copied as ``debug.c`` into a temporary directory together with a generated
    Makefile, built against ``/opt/linux/linux``, and the resulting
    ``debug.ko`` is copied to ``/challenge/debug.ko``. Build output goes to
    standard error.

    Args:
        path: Path of the C source file to build.

    Raises:
        subprocess.CalledProcessError: If ``make`` exits non-zero.
    """
    os.setuid(os.geteuid())

    with open(path, "r") as source_file:
        source = source_file.read()

    with tempfile.TemporaryDirectory() as workdir:
        makefile = textwrap.dedent(
            f"""
            obj-m += debug.o

            all:
            \tmake -C /opt/linux/linux M={workdir} modules
            clean:
            \tmake -C /opt/linux/linux M={workdir} clean
            """
        )

        # Write the module source first, then its Makefile.
        for name, content in (("debug.c", source), ("Makefile", makefile)):
            with open(f"{workdir}/{name}", "w") as out:
                out.write(content)

        make_argv = ["make", "-C", workdir]
        subprocess.run(make_argv, stdout=sys.stderr, check=True)

        built_module = f"{workdir}/debug.ko"
        shutil.copy(built_module, "/challenge/debug.ko")


def logs():
    """Stream ``/run/vm/vm.log`` until interrupted.

    Runs ``tail -F -n+1`` on the log file with its standard input and standard
    error discarded. Whenever tail exits it is started again after a
    0.1 second pause, so this function never returns normally.
    """
    tail_cmd = ["/usr/bin/tail", "-F", "-n+1", "/run/vm/vm.log"]
    run_options = {"stdin": subprocess.DEVNULL, "stderr": subprocess.DEVNULL}

    while True:
        subprocess.run(tail_cmd, **run_options)
        time.sleep(0.1)


def main():
    """Parse the command line, enforce privilege rules and run the sub-command.

    The parsed arguments are stored in the module-level ``args`` so that other
    functions can read them. Unprivileged users may not change boot flags or
    use the ``debug`` and ``build`` commands. A KeyboardInterrupt raised while
    the command runs is swallowed.
    """
    global args

    initialize()

    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers(dest="command", required=True)

    subparsers.add_parser("connect", help="connect to vm")

    exec_parser = subparsers.add_parser("exec", help="exec command in vm")
    exec_parser.add_argument("exec_command")
    exec_parser.add_argument("exec_command_args", nargs=argparse.REMAINDER)

    # Sub-commands that take no arguments of their own.
    for name, help_text in (
        ("start", "start vm"),
        ("stop", "stop vm"),
        ("restart", "restart vm"),
        ("logs", "show vm logs"),
        ("debug", "privileged: debug vm"),
    ):
        subparsers.add_parser(name, help=help_text)

    build_parser = subparsers.add_parser("build", help="privileged: build vm kernel module")
    build_parser.add_argument("build_path")

    parser.add_argument('--kaslr', dest='nokaslr', action='store_false', help="privileged: enable kaslr")
    parser.add_argument('--no-kaslr', dest='nokaslr', action='store_true', help="privileged: disable kaslr")
    parser.set_defaults(nokaslr=None)

    args = parser.parse_args()

    if not is_privileged():
        restricted = (
            ("modify boot flags", args.nokaslr is not None),
            ("debug", args.command == "debug"),
            ("build", args.command == "build"),
        )
        for description, attempted in restricted:
            if attempted:
                error(f"Error: do not have permission to {description}")

    command = args.command
    try:
        if command == "connect":
            start()
            connect()
        elif command == "exec":
            start()
            exec_(args.exec_command, *args.exec_command_args)
        elif command in ("start", "restart"):
            # Inside the VM both commands reboot the guest instead.
            if not in_vm():
                stop()
                start()
            else:
                vm_reboot()
        elif command == "stop":
            if not in_vm():
                stop()
            else:
                vm_power_off()
        elif command == "debug":
            if not in_vm():
                debug()
            else:
                error("Error: cannot debug from within vm")
        elif command == "build":
            build(args.build_path)
        elif command == "logs":
            logs()
        else:
            # Unreachable with the parser above; mirrors a failed table lookup.
            raise KeyError(command)
    except KeyboardInterrupt:
        pass


# Run the command-line interface only when executed as a script.
if __name__ == '__main__':
    main()
